{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial: Online RL for Multi-Hop Research\n",
    "\n",
    "WARNING: This feature is new and extremely EXPERIMENTAL. Unlike almost everything else in DSPy, it's currently in pure proof of concept and development mode, but we release it to encourage community involvement.\n",
    "\n",
    "For this tutorial, you will also need [DSPy's Arbor RL framework](https://github.com/Ziems/arbor) which you can install with:\n",
    "\n",
    "```bash\n",
    "> pip install -U arbor-ai\n",
    "```\n",
    "\n",
    "You may also have to install DSPy from the main branch:\n",
    "```bash\n",
    "> pip install -U git+https://github.com/stanfordnlp/dspy.git@main\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dspy\n",
    "import arbor\n",
    "from arbor import ArborGRPO, ArborProvider\n",
    "arbor_server_info = arbor.init() # Initialize the Arbor server in the background\n",
    "\n",
    "port = 7453\n",
    "local_lm_name = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "local_lm = dspy.LM(\n",
    "    model=f\"openai/arbor:{local_lm_name}\",\n",
    "    provider=ArborProvider(),\n",
    "    api_base=arbor_server_info[\"base_url\"],\n",
    "    # Arbor checks to make sure these match the training config\n",
    "    temperature=1.0,\n",
    "    top_p=1.0,\n",
    "    top_k=-1,\n",
    "    repetition_penalty=1.0,\n",
    "    max_tokens=2048,\n",
    ")\n",
    "\n",
    "dspy.configure(lm=local_lm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install dependencies and download data\n",
    "\n",
    "To do the retrieval, we'll use the cool BM25S library, as it's pretty lightweight. You can replace this components with whatever you like.\n",
    "\n",
    "```shell\n",
    "> pip install -U bm25s PyStemmer \"jax[cpu]\"\n",
    "```\n",
    "\n",
    "Next, we'll download a snapshot abstracts (i.e., first paragraphs) of all 5,000,000 Wikipedia pages as of 2017. We'll use this as our retrieval corpus.\n",
    "\n",
    "This is 500MB compressed, so the download and decompression may take 2-3 minutes.\n",
    "\n",
    "```python\n",
    "from dspy.utils import download\n",
    "\n",
    "download(\"https://huggingface.co/dspy/cache/resolve/main/wiki.abstracts.2017.tar.gz\")\n",
    "!tar -xzvf wiki.abstracts.2017.tar.gz\n",
    "```\n",
    "\n",
    "And then let's index it for BM25 retrieval! This will take 2-3 minutes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ujson\n",
    "import bm25s\n",
    "import Stemmer\n",
    "\n",
    "corpus = []\n",
    "\n",
    "with open(\"wiki.abstracts.2017.jsonl\") as f:\n",
    "    for line in f:\n",
    "        line = ujson.loads(line)\n",
    "        corpus.append(f\"{line['title']} | {' '.join(line['text'])}\")\n",
    "\n",
    "stemmer = Stemmer.Stemmer(\"english\")\n",
    "corpus_tokens = bm25s.tokenize(corpus, stopwords=\"en\", stemmer=stemmer)\n",
    "\n",
    "retriever = bm25s.BM25(k1=0.9, b=0.4)\n",
    "retriever.index(corpus_tokens)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the HoVer dataset.\n",
    "\n",
    "Let's load a dataset for our task. We'll load examples from the HoVer multi-hop task, where the input is a (really!) complex claim and the output we're seeking is the set of Wikipedia pages that are required to fact-check that claim.\n",
    "\n",
    "You may have to install an older version of the dataset to get it working properly...\n",
    "```shell\n",
    "> pip install datasets==3.6.0\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from dspy.datasets import DataLoader\n",
    "\n",
    "kwargs = dict(fields=(\"claim\", \"supporting_facts\", \"hpqa_id\", \"num_hops\"), input_keys=(\"claim\",))\n",
    "hover = DataLoader().from_huggingface(dataset_name=\"hover-nlp/hover\", split=\"train\", trust_remote_code=True, **kwargs)\n",
    "\n",
    "hpqa_ids = set()\n",
    "hover = [\n",
    "    dspy.Example(claim=x.claim, titles=list(set([y[\"key\"] for y in x.supporting_facts]))).with_inputs(\"claim\")\n",
    "    for x in hover\n",
    "    if x[\"num_hops\"] == 3 and x[\"hpqa_id\"] not in hpqa_ids and not hpqa_ids.add(x[\"hpqa_id\"])\n",
    "]\n",
    "\n",
    "random.Random(0).shuffle(hover)\n",
    "trainset, devset, testset = hover[:600], hover[600:900], hover[900:]\n",
    "len(trainset), len(devset), len(testset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's define a function to do the search in Wikipedia. This will use our BM25 index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def search(query: str, k: int) -> list[str]:\n",
    "    tokens = bm25s.tokenize(query, stopwords=\"en\", stemmer=stemmer, show_progress=False)\n",
    "    results, scores = retriever.retrieve(tokens, k=k, n_threads=1, show_progress=False)\n",
    "    run = {corpus[doc]: float(score) for doc, score in zip(results[0], scores[0])}\n",
    "    return list(run.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A DSPy program for multi-hop research\n",
    "\n",
    "Now, let's define the multi-hop program in DSPy. It's going to be super simple, composed of `generate_query` and `append_notes` modules. We'll define the instructions carefully, though they are typically not necessary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "instr1 = \"\"\"\n",
    "Given a claim and some key facts, generate a follow-up search query to find the next most essential clue towards verifying or refuting the claim. The goal ultimately is to find all documents implicated by the claim.\n",
    "\"\"\".strip()\n",
    "\n",
    "instr2 = \"\"\"\n",
    "Given a claim, some key facts, and new search results, identify any new learnings from the new search results, which will extend the key facts known so far about the whether the claim is true or false. The goal is to ultimately collect all facts that would help us find all documents implicated by the claim.\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "class ResearchHop(dspy.Module):\n",
    "    def __init__(self, num_docs, num_hops):\n",
    "        self.num_docs, self.num_hops = num_docs, num_hops\n",
    "        self.generate_query = dspy.ChainOfThought(dspy.Signature(\"claim, key_facts -> followup_search_query\", instr1))\n",
    "        self.append_notes = dspy.ChainOfThought(dspy.Signature(\"claim, key_facts, new_search_results -> new_key_facts\", instr2))\n",
    "\n",
    "    def forward(self, claim: str) -> list[str]:\n",
    "        key_facts = []\n",
    "        retrieved_docs = []\n",
    "\n",
    "        for hop_idx in range(self.num_hops):\n",
    "            query = self.generate_query(claim=claim, key_facts=key_facts).followup_search_query if hop_idx else claim\n",
    "            search_results = search(query, k=self.num_docs)\n",
    "            retrieved_docs.extend(search_results)\n",
    "\n",
    "            if hop_idx == self.num_hops - 1:\n",
    "                break\n",
    "                \n",
    "            prediction = self.append_notes(claim=claim, key_facts=key_facts, new_search_results=search_results)\n",
    "            key_facts.append(prediction.new_key_facts)\n",
    "\n",
    "        return dspy.Prediction(key_facts=key_facts, retrieved_docs=retrieved_docs)"
   ]
  },
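  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the control flow of `forward` concrete, here is a minimal sketch of the same loop with the LM calls and the BM25 search swapped for plain Python stand-ins. The names `trace_hops`, `make_query`, and `take_notes` are hypothetical and exist only for this illustration; they are not part of DSPy or Arbor.\n",
    "\n",
    "```python\n",
    "def trace_hops(claim, num_hops, make_query, search, take_notes):\n",
    "    key_facts, retrieved_docs, queries = [], [], []\n",
    "    for hop_idx in range(num_hops):\n",
    "        # The first hop searches with the raw claim; later hops use a generated query.\n",
    "        query = make_query(claim, key_facts) if hop_idx else claim\n",
    "        queries.append(query)\n",
    "        results = search(query)\n",
    "        retrieved_docs.extend(results)\n",
    "        if hop_idx == num_hops - 1:\n",
    "            break  # no notes are needed after the final search\n",
    "        key_facts.append(take_notes(claim, key_facts, results))\n",
    "    return queries, key_facts, retrieved_docs\n",
    "\n",
    "\n",
    "# Hypothetical stand-ins for the two LM modules and the retriever.\n",
    "queries, key_facts, docs = trace_hops(\n",
    "    claim=\"toy claim\",\n",
    "    num_hops=2,\n",
    "    make_query=lambda claim, facts: f\"follow-up using {len(facts)} fact(s)\",\n",
    "    search=lambda query: [f\"doc for: {query}\"],\n",
    "    take_notes=lambda claim, facts, results: f\"note from {results[0]}\",\n",
    ")\n",
    "print(queries)\n",
    "print(key_facts)\n",
    "```\n",
    "\n",
    "With `num_hops=2`, the first search uses the raw claim, the second uses a generated query, and notes are taken only once, between the two searches."
   ]
  },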
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define metrics for success in this task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def recall(example, pred, trace=None):\n",
    "    gold_titles = example.titles\n",
    "    retrieved_titles = [doc.split(\" | \")[0] for doc in pred.retrieved_docs]\n",
    "    return sum(x in retrieved_titles for x in set(gold_titles)) / len(gold_titles)\n",
    "\n",
    "evaluate = dspy.Evaluate(devset=devset, metric=recall, num_threads=16, display_progress=True, display_table=5)"
   ]
  },
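  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the metric, the sketch below applies the same `recall` logic to hand-made inputs. The titles and documents are hypothetical toy values, not HoVer data, and `SimpleNamespace` stands in for `dspy.Example` and `dspy.Prediction` so the snippet runs without DSPy.\n",
    "\n",
    "```python\n",
    "from types import SimpleNamespace\n",
    "\n",
    "\n",
    "def recall(example, pred, trace=None):\n",
    "    gold_titles = example.titles\n",
    "    # Each retrieved document has the form \"<title> | <abstract>\".\n",
    "    retrieved_titles = [doc.split(\" | \")[0] for doc in pred.retrieved_docs]\n",
    "    return sum(x in retrieved_titles for x in set(gold_titles)) / len(gold_titles)\n",
    "\n",
    "\n",
    "# Hypothetical toy inputs: three gold pages, two of which were retrieved.\n",
    "example = SimpleNamespace(titles=[\"Page A\", \"Page B\", \"Page C\"])\n",
    "pred = SimpleNamespace(\n",
    "    retrieved_docs=[\n",
    "        \"Page A | first abstract\",\n",
    "        \"Page C | another abstract\",\n",
    "        \"Page D | unrelated abstract\",\n",
    "    ]\n",
    ")\n",
    "score = recall(example, pred)\n",
    "print(round(score, 3))\n",
    "```\n",
    "\n",
    "Two of the three gold titles appear among the retrieved documents, so the score is 2/3. Only the text before the first ` | ` separator is compared, which matches how the corpus entries were built when indexing."
   ]
  },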
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimize the `ResearchHop` system with `dspy.GRPO`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "program = ResearchHop(num_docs=4, num_hops=2)\n",
    "program.set_lm(local_lm)\n",
    "\n",
    "# NOTE: Training on 4 GPUs.\n",
    "train_kwargs = {\n",
    "    \"per_device_train_batch_size\": 2,\n",
    "    \"gradient_accumulation_steps\": 24/6,\n",
    "    \"temperature\": 1.0,\n",
    "    \"top_k\": -1,\n",
    "    \"top_p\": 1.0,\n",
    "    \"repetition_penalty\": 1.0,\n",
    "    \"beta\": 0.00,\n",
    "    \"learning_rate\": 1e-6,\n",
    "    \"gradient_checkpointing\": True,\n",
    "    \"bf16\": True,\n",
    "    \"lr_scheduler_type\": \"constant_with_warmup\",\n",
    "    \"loss_type\": \"dapo\",\n",
    "    \"max_steps\": 1000,\n",
    "    \"report_to\": \"wandb\",\n",
    "    \"log_completions\": True,\n",
    "    \"logging_steps\": 1,\n",
    "    \"max_prompt_length\": None,\n",
    "    \"max_completion_length\": None,\n",
    "    \"scale_rewards\": False,\n",
    "    \"max_grad_norm\": 1.0,\n",
    "    \"lora_config\": {\n",
    "        \"lora_alpha\": 16,\n",
    "        \"lora_dropout\": 0.05,\n",
    "        \"r\": 8,\n",
    "        \"target_modules\": [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\", \"gate_proj\"],\n",
    "    },\n",
    "    \"num_training_gpus\": 3,\n",
    "    \"num_inference_gpus\": 1,\n",
    "    \"weight_decay\": 0.001,\n",
    "}\n",
    "\n",
    "compiler = ArborGRPO(\n",
    "    metric=recall,\n",
    "    num_dspy_examples_per_grpo_step=6,\n",
    "    num_rollouts_per_grpo_step=24,\n",
    "    exclude_demos=True,\n",
    "    num_train_steps=1000,\n",
    "    num_threads=16,\n",
    "    use_train_as_val=False,\n",
    "    num_steps_for_val=50,\n",
    "    train_kwargs=train_kwargs,\n",
    "    checkpoint=\"single-best\",\n",
    ")\n",
    "\n",
    "optimized_program = compiler.compile(\n",
    "    student=program,\n",
    "    trainset=trainset,\n",
    "    valset=devset,\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, you can use the GRPO'ed program."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "example = devset[0]\n",
    "optimized_program(**example.inputs())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In our preliminary experiments, training about 18 hours boosts the recall (devset) from 61.8% to 66.2%. This is _typically_ worse on cost/quality basis than you'd get from running prompt optimizers dspy.MIPROv2 or dspy.SIMBA, but it's still a very solid start for online RL over arbitrary LM programs for small LMs."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "arbor-exps",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
